
Disclaimer: ⚠️ Do not use this code to decide what to put in your omelette! 🍳
It's important to note that the identification and classification of mushrooms should be done carefully, as some mushrooms can be toxic or deadly. Always consult a mycologist or use a reliable field guide when identifying mushrooms.
💡 Remember to read the README.md file for helpful hints on the best ways to view/navigate this project.
If you view this notebook on GitHub, you will be missing important content.
Some charts/diagrams/features are not visible on GitHub; this is standard, well-known behaviour.
Consider viewing the pre-rendered HTML files, or run all notebooks end to end after enabling the feature flags that control long-running operations.
If you choose to run this locally, there are some prerequisites:
- Python 3.9
- `pip install -r requirements.txt` before proceeding.

# Sprint 1: Computer Vision - Know Your Mushrooms
# Background
In the US alone, around 7,500 cases of mushroom poisoning are reported every year [(Source)](https://www.tandfonline.com/doi/full/10.1080/00275514.2018.1479561). According to the source, "misidentification of edible mushroom species appears to be the most common cause and may be preventable through education". To avoid hospitalization expenses and, in some cases, pointless deaths, you have been hired by the US National Health Service to create a machine learning model that can recognize mushroom types. They want to install it on hand-held devices to help people make the right choice when mushroom picking.
# Concepts to explore
Today, we will put everything we learned in this module to use and solve a classification problem. The idea of this project is to apply transfer learning to an architecture of your choice and fine-tune it to predict mushroom types.
You will use this Kaggle dataset <https://www.kaggle.com/maysee/mushrooms-classification-common-genuss-images>
# How to start?
## Data
Well, the obvious first steps will be getting the data from Kaggle. There are a number of choices on how to do it, such as downloading images to your machine and then uploading to Drive or using [Kaggle API](https://github.com/Kaggle/kaggle-api). Once you get your data, start with an EDA, as this will directly feed into design choices for your architecture.
## Modeling
My suggestion is that you start with a simple pre-trained architecture, like ResNet18. This will allow you to fine-tune your net faster and if results are not too good, you can try switching to a larger model later. It is recommended that you use PyTorch Lightning or FastAI. Both are equally good for simple problems like this, but PyTorch Lightning will give you more control, better customization ability, and better understanding of your network.
# Requirements
- Choose whichever framework you prefer from FastAI, PyTorch Lightning or PyTorch.
- As always - EDA
- Use a pre-trained neural net as a backbone of your class
- Train a classifier. Don't forget to fine-tune
- Evaluate inference time
- Visualize results
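One requirement above is to evaluate inference time. A minimal, framework-agnostic sketch of how that can be measured: warm up first, then time repeated calls and report robust statistics. The lambda below is a hypothetical stand-in for the trained model's forward pass (on a GPU you would also need to synchronize the device before reading the clock):

```python
import time
import statistics

def measure_latency(predict_fn, batch, warmup=3, repeats=10):
    """Time a prediction callable: a few warmup calls, then repeated timed calls."""
    for _ in range(warmup):
        predict_fn(batch)  # warmup: caches, lazy initialization, etc.
    timings = []
    for _ in range(repeats):
        start = time.perf_counter()
        predict_fn(batch)
        timings.append(time.perf_counter() - start)
    # median is more robust to scheduling hiccups than the mean
    return {"median_s": statistics.median(timings), "mean_s": statistics.fmean(timings)}

# Stand-in for model inference: any callable that takes a batch works.
stats = measure_latency(lambda b: [x * 2 for x in b], list(range(1000)))
print(stats)
```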
# Evaluation Criteria
- Model performance
- Classification performance
- Inference speed
- EDA and documented findings
- Results analysis
- Code quality
# Bonus challenges
- Repeat the process with modifications to your network and see how the results vary.
- Try a different optimizer
- Add an intermediate layer between the backbone and output layer
# Sample correction questions
During a correction, you may get asked questions that test your understanding of covered topics.
- Describe how a convolutional layer works
- What is overfitting? Describe why it is bad (or good) and how to detect it.
- What is an optimizer? Describe at a high level how it works.
- What are the advantages/disadvantages of transfer learning?
💡 After receiving feedback on a past project that I tend to murder my reviewers by going too in depth, with excessive detail and explaining too many things, I will try to keep the descriptions of what I'm thinking/considering a bit more concise. Happy to hear your feedback.
from IPython.display import display, Markdown, clear_output, HTML, IFrame
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import itertools
import glob
import numpy as np
import pandas as pd
import seaborn as sns
import torch
from torch import nn
from torch.utils.data import DataLoader, random_split, WeightedRandomSampler
from torchvision import transforms, datasets, models
import torch.nn.functional as F
from torchmetrics.classification import (
MulticlassConfusionMatrix,
MulticlassAccuracy,
MulticlassPrecision,
MulticlassRecall,
MulticlassF1Score,
)
from torchmetrics import ConfusionMatrix
# import torchvision.models as models
import pytorch_lightning as pl
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
from optuna.integration import PyTorchLightningPruningCallback
from PIL import Image, ImageDraw, ImageFile
import albumentations as A
import plotly.express as px
import plotly.io as pio
from scipy import stats
from scipy.stats import chi2_contingency
import missingno as msno
from sklearn.metrics import ConfusionMatrixDisplay
from random import random, seed, shuffle
import logging
import warnings
import os
from os import path
from utils import *
from utils import __
loading utils modules...
/home/edu/anaconda3/envs/py39_lab4/lib/python3.9/site-packages/xgboost/compat.py:93: FutureWarning: pandas.Int64Index is deprecated and will be removed from pandas in a future version. Use pandas.Index with the appropriate dtype instead. from pandas import MultiIndex, Int64Index
✔ completed configuring autoreload... ✔ completed
from keras.utils import plot_model
from keras.applications.resnet50 import ResNet50
2023-12-10 09:43:11.636571: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2023-12-10 09:43:11.654057: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2023-12-10 09:43:11.654071: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2023-12-10 09:43:11.654700: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2023-12-10 09:43:11.657998: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations. To enable the following instructions: AVX2 AVX512F AVX512_VNNI AVX512_BF16 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
2023-12-10 09:43:12.044040: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
print(f"{np.__version__ = }")
print(f"{pd.__version__ = }")
print(f"{torch.__version__ = }")
print(f"{pl.__version__ = }")
np.__version__ = '1.23.5'
pd.__version__ = '1.5.3'
torch.__version__ = '2.1.0+cu121'
pl.__version__ = '2.1.1'
seed(100)
pd.options.display.max_rows = 100
pd.options.display.max_colwidth = 50
pio.renderers.default = "png"
ImageFile.LOAD_TRUNCATED_IMAGES = True
util.check("done")
✔
%reload_ext mushrooms_utils
import mushrooms_utils as mushroom
Let's use black to auto-format all our cells so they adhere to PEP8
import lab_black
%reload_ext lab_black
util.patch_nb_black()
# fmt: off
# fmt: on
from sklearn import set_config
set_config(transform_output="pandas")
sns.set_theme(context="notebook", style="whitegrid")
moonstone = "#62b6cb"
moonstone_rgb = util.hex_to_rgb(moonstone)
moonstone_rgb_n = np.array(moonstone_rgb) / 255
logger = util.configure_logging(jupyterlab_level=logging.WARN, file_level=logging.DEBUG)
warnings.filterwarnings("ignore", category=FutureWarning)
# import warnings
# warnings.filterwarnings('error', category=pd.errors.DtypeWarning)
def ding(title="Ding!", message="Task completed"):
"""
this method only works on linux
"""
for i in range(2):
!notify-send '{title}' '{message}'
Let's also create a simple feature toggle that we can use to skip expensive operations during notebook work (to save myself some time!)
Set it to `True` if you want to run absolutely everything; set it to `False` to skip optional steps/exploratory work.
def run_entire_notebook(filename: str = None):
    run_all = False
    if not run_all:
        print("skipping optional operation")
        fullpath = f"cached/printouts/{filename}.txt"
        if filename is not None and os.path.exists(fullpath):
            print("==== 🗃️ printing cached output ====")
            with open(fullpath) as f:
                print(f.read())
    return run_all
kaggle_dataset_name = "maysee/mushrooms-classification-common-genuss-images"
db_filename = "Mushrooms"
auto_kaggle.download_dataset(kaggle_dataset_name, db_filename, timeout_seconds=3 * 60)
__
Kaggle API 1.5.13 - login as 'edualmas' File [dataset/Mushrooms] already exists locally! No need to re-download data for dataset [maysee/mushrooms-classification-common-genuss-images]
A quick check shows that the dataset has 2 identical folders, with the data duplicated.
We will remove one of the directories altogether.

if os.path.exists("dataset/mushrooms/"):
!rm -rf dataset/mushrooms/
print("removed duplicated images")
else:
print("folder with duplicated images has already been removed")
folder with duplicated images has already been removed
We need to split our dataset into a few chunks, and there seems to be no standard out-of-the-box way to achieve everything we need, so we have to do the split ourselves.
In particular, PyTorch `Subset`s are basically views of the original dataset, and they all share the same transforms. That would normally be fine, but we want to apply different preprocessing to different splits.
So, here we go again, reinventing the wheel.
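The split itself is done per class folder by `split_utils.split_image_dataset` (a project helper). Conceptually it amounts to something like this sketch, where a class's files are shuffled with a fixed seed and partitioned by the requested fractions (the function name and file names here are illustrative, not the helper's actual API):

```python
import random

def split_filenames(filenames, fractions, seed=42):
    """Partition a list of filenames into named splits by fraction.

    fractions: e.g. {"train": 0.65, "hp": 0.10, "val": 0.10, "test": 0.15}
    The last split absorbs rounding leftovers so no file is dropped.
    """
    assert abs(sum(fractions.values()) - 1.0) < 1e-9
    files = sorted(filenames)
    random.Random(seed).shuffle(files)  # deterministic shuffle
    splits, start = {}, 0
    names = list(fractions)
    for i, name in enumerate(names):
        if i == len(names) - 1:
            splits[name] = files[start:]  # last split takes the remainder
        else:
            end = start + round(len(files) * fractions[name])
            splits[name] = files[start:end]
            start = end
    return splits

splits = split_filenames([f"img_{i:04}.jpg" for i in range(100)],
                         {"train": 0.65, "hp": 0.10, "val": 0.10, "test": 0.15})
print({k: len(v) for k, v in splits.items()})  # → {'train': 65, 'hp': 10, 'val': 10, 'test': 15}
```

Because the splits are disjoint slices of one shuffled list, a file can never land in two splits, which is exactly the data-leakage guarantee we assert after the real split.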
if not os.path.exists("dataset/Mushrooms_split/"):
operation_log = split_utils.split_image_dataset(
"dataset/Mushrooms",
{"train": 0.65, "hp": 0.10, "val": 0.10, "test": 0.15},
glob_filename_pattern="*.jpg",
)
display(operation_log.head())
assert not operation_log["source_file"].duplicated().any()
else:
print("dataset is already split in [dataset/Mushrooms_split]")
dataset is already split in [dataset/Mushrooms_split]
PASS: each file has only been copied once ✔
No data leakage ✔
Let's also check that we lost no data:
with fs_utils.in_subdir("dataset/"):
!find Mushrooms -type d -exec sh -c 'echo files: "$(find "{}" -type f | wc -l) \t {}"' \;
files: 6714 	 Mushrooms
files: 1073 	 Mushrooms/Boletus
files: 353 	 Mushrooms/Agaricus
files: 750 	 Mushrooms/Amanita
files: 1563 	 Mushrooms/Lactarius
files: 311 	 Mushrooms/Suillus
files: 836 	 Mushrooms/Cortinarius
files: 364 	 Mushrooms/Entoloma
files: 316 	 Mushrooms/Hygrocybe
files: 1148 	 Mushrooms/Russula
with fs_utils.in_subdir("dataset/"):
!find Mushrooms_split -type d -exec sh -c 'echo files: "$(find "{}" -type f | wc -l) \t {}"' \;
files: 6714 	 Mushrooms_split
files: 4364 	 Mushrooms_split/train
files: 697 	 Mushrooms_split/train/Boletus
files: 230 	 Mushrooms_split/train/Agaricus
files: 488 	 Mushrooms_split/train/Amanita
files: 1016 	 Mushrooms_split/train/Lactarius
files: 202 	 Mushrooms_split/train/Suillus
files: 543 	 Mushrooms_split/train/Cortinarius
files: 237 	 Mushrooms_split/train/Entoloma
files: 205 	 Mushrooms_split/train/Hygrocybe
files: 746 	 Mushrooms_split/train/Russula
files: 1008 	 Mushrooms_split/test
files: 162 	 Mushrooms_split/test/Boletus
files: 53 	 Mushrooms_split/test/Agaricus
files: 112 	 Mushrooms_split/test/Amanita
files: 235 	 Mushrooms_split/test/Lactarius
files: 47 	 Mushrooms_split/test/Suillus
files: 125 	 Mushrooms_split/test/Cortinarius
files: 55 	 Mushrooms_split/test/Entoloma
files: 47 	 Mushrooms_split/test/Hygrocybe
files: 172 	 Mushrooms_split/test/Russula
files: 671 	 Mushrooms_split/val
files: 107 	 Mushrooms_split/val/Boletus
files: 35 	 Mushrooms_split/val/Agaricus
files: 75 	 Mushrooms_split/val/Amanita
files: 156 	 Mushrooms_split/val/Lactarius
files: 31 	 Mushrooms_split/val/Suillus
files: 84 	 Mushrooms_split/val/Cortinarius
files: 36 	 Mushrooms_split/val/Entoloma
files: 32 	 Mushrooms_split/val/Hygrocybe
files: 115 	 Mushrooms_split/val/Russula
files: 671 	 Mushrooms_split/hp
files: 107 	 Mushrooms_split/hp/Boletus
files: 35 	 Mushrooms_split/hp/Agaricus
files: 75 	 Mushrooms_split/hp/Amanita
files: 156 	 Mushrooms_split/hp/Lactarius
files: 31 	 Mushrooms_split/hp/Suillus
files: 84 	 Mushrooms_split/hp/Cortinarius
files: 36 	 Mushrooms_split/hp/Entoloma
files: 32 	 Mushrooms_split/hp/Hygrocybe
files: 115 	 Mushrooms_split/hp/Russula
Let's take a look at the raw dataset that we have retrieved from Kaggle.
@run
@cached_chart(extension="jpg")
def mushroom_pics_sample():
# yes, "genera" would be a more scientifically accurate name...
# but we're aiming for code that is easy to read by
# a non-scientific audience (us, developers!)
all_genuses = mushroom.Genus
imgs_per_genus = 10
f, ax = plt.subplots(len(all_genuses), imgs_per_genus, figsize=(15, 15))
for g, genus in enumerate(all_genuses):
image_dir = f"dataset/Mushrooms_split/train/{genus.value.dir_name}"
images = (
os.path.join(image_dir, file)
for file in os.listdir(image_dir)
if file.endswith(".jpg")
)
image_files = list(itertools.islice(images, imgs_per_genus))
for i, image_file in enumerate(image_files):
dirname = os.path.basename(os.path.dirname(image_file))
filename = os.path.basename(image_file)
image = Image.open(image_file)
ax[g, i].imshow(image)
ax[g, i].axis("off")
ax[g, i].set_title(dirname + "/\n" + filename[:9] + "...")
plt.tight_layout()
plt.show()
return f
Loading from cache [./cached/charts/mushroom_pics_sample.jpg]
A few thoughts:
Let's take a closer look at a few pics to understand how they are structured
@run
@cached_chart(extension="jpg")
def mushroom_pics_sample_1():
f = plt.figure()
pic1 = "dataset/Mushrooms_split/train/Agaricus/102_BV5Swi4Xfjc.jpg"
image1 = Image.open(pic1)
plt.imshow(image1)
return f
Loading from cache [./cached/charts/mushroom_pics_sample_1.jpg]
We can understand a picture as a matrix of pixels.
We can also retrieve the 3 colour channels independently.
@run
@cached_chart(extension="jpg")
def mushroom_pics_sample_2():
f = plt.figure()
pic2 = "dataset/Mushrooms_split/train/Cortinarius/070_hnlGwobiKIs.jpg"
image2 = Image.open(pic2)
plt.imshow(image2.resize((280, 300)))
charts.remove_axes()
return f
# SVD ~= PCA for removing noise/extra info from images
Loading from cache [./cached/charts/mushroom_pics_sample_2.jpg]
@run
@cached_chart(extension="jpg")
def mushroom_pics_sample_3():
f, ax = plt.subplots(2, 3, figsize=(15, 9))
pic2 = "dataset/Mushrooms_split/train/Cortinarius/070_hnlGwobiKIs.jpg"
image2 = Image.open(pic2)
ax[0, 0].set_title("red channel...")
ax[0, 1].set_title("green channel...")
ax[0, 2].set_title("blue channel..")
ax[0, 0].imshow(np.array(image2)[:, :, 0], cmap="Reds_r")
ax[0, 1].imshow(np.array(image2)[:, :, 1], cmap="Greens_r")
ax[0, 2].imshow(np.array(image2)[:, :, 2], cmap="Blues_r")
ax[1, 0].set_title("...in grayscale")
ax[1, 1].set_title("...in grayscale")
ax[1, 2].set_title("...in grayscale")
ax[1, 0].imshow(np.array(image2)[:, :, 0], cmap="Greys_r")
ax[1, 1].imshow(np.array(image2)[:, :, 1], cmap="Greys_r")
ax[1, 2].imshow(np.array(image2)[:, :, 2], cmap="Greys_r")
return f
Loading from cache [./cached/charts/mushroom_pics_sample_3.jpg]
Let's take a look at the classes that we have and how they are distributed/shaped.
classes = glob.glob("dataset/Mushrooms_split/train/*")
classes = {
    p.replace("dataset/Mushrooms_split/train/", "").lower(): p for p in classes
}
classes
{'boletus': 'dataset/Mushrooms_split/train/Boletus',
'agaricus': 'dataset/Mushrooms_split/train/Agaricus',
'amanita': 'dataset/Mushrooms_split/train/Amanita',
'lactarius': 'dataset/Mushrooms_split/train/Lactarius',
'suillus': 'dataset/Mushrooms_split/train/Suillus',
'cortinarius': 'dataset/Mushrooms_split/train/Cortinarius',
'entoloma': 'dataset/Mushrooms_split/train/Entoloma',
'hygrocybe': 'dataset/Mushrooms_split/train/Hygrocybe',
'russula': 'dataset/Mushrooms_split/train/Russula'}
files = {f: c for c, p in classes.items() for f in glob.glob(f"{p}/*.jpg")}
df_files = pd.DataFrame({"path": files.keys(), "genus": files.values()})
df_files.sample(n=10)
| path | genus | |
|---|---|---|
| 2764 | dataset/Mushrooms_split/train/Cortinarius/105_... | cortinarius |
| 1120 | dataset/Mushrooms_split/train/Amanita/079_6ZA8... | amanita |
| 2537 | dataset/Mushrooms_split/train/Suillus/155_hPsV... | suillus |
| 4325 | dataset/Mushrooms_split/train/Russula/617_B_F5... | russula |
| 3846 | dataset/Mushrooms_split/train/Russula/036_7WKO... | russula |
| 2170 | dataset/Mushrooms_split/train/Lactarius/410_gH... | lactarius |
| 1578 | dataset/Mushrooms_split/train/Lactarius/249_4V... | lactarius |
| 4306 | dataset/Mushrooms_split/train/Russula/302_Yq1G... | russula |
| 1772 | dataset/Mushrooms_split/train/Lactarius/1079_D... | lactarius |
| 2509 | dataset/Mushrooms_split/train/Suillus/097_JAf_... | suillus |
def get_picture_size(path):
with Image.open(path) as img:
width, height = img.size
return width, height
def get_picture_height(path):
return get_picture_size(path)[1]
def get_picture_width(path):
return get_picture_size(path)[0]
df_files["height"] = df_files["path"].map(get_picture_height)
df_files["width"] = df_files["path"].map(get_picture_width)
df_files
| path | genus | height | width | |
|---|---|---|---|---|
| 0 | dataset/Mushrooms_split/train/Boletus/0201_PO7... | boletus | 693 | 960 |
| 1 | dataset/Mushrooms_split/train/Boletus/0048_wGU... | boletus | 567 | 800 |
| 2 | dataset/Mushrooms_split/train/Boletus/0112_s0d... | boletus | 600 | 800 |
| 3 | dataset/Mushrooms_split/train/Boletus/0664_aAv... | boletus | 404 | 570 |
| 4 | dataset/Mushrooms_split/train/Boletus/1089_w6Z... | boletus | 517 | 800 |
| ... | ... | ... | ... | ... |
| 4359 | dataset/Mushrooms_split/train/Russula/342_1iB1... | russula | 585 | 780 |
| 4360 | dataset/Mushrooms_split/train/Russula/044_OC94... | russula | 600 | 800 |
| 4361 | dataset/Mushrooms_split/train/Russula/361_IK7M... | russula | 533 | 800 |
| 4362 | dataset/Mushrooms_split/train/Russula/170_1oaK... | russula | 535 | 800 |
| 4363 | dataset/Mushrooms_split/train/Russula/209_9N--... | russula | 600 | 800 |
4364 rows × 4 columns
@run
@cached_chart()
def image_count_by_genus():
f = plt.figure()
order = df_files["genus"].value_counts().index
sns.countplot(data=df_files, y="genus", order=order, color=moonstone)
plt.title("count of raw images per genus")
return f
Loading from cache [./cached/charts/image_count_by_genus.png]
Our aim is to have a few hundred examples per class in the training split.
For now we will continue with the EDA; this resampling/rebalancing will be done later.
Let's take a look at the image resolutions and how they are distributed
resolutions = df_files.groupby("genus")[["width", "height"]].agg(["min", "max"])
resolutions
| width | height | |||
|---|---|---|---|---|
| min | max | min | max | |
| genus | ||||
| agaricus | 423 | 1200 | 282 | 906 |
| amanita | 275 | 1280 | 183 | 1024 |
| boletus | 262 | 1200 | 192 | 948 |
| cortinarius | 259 | 1280 | 152 | 1024 |
| entoloma | 275 | 1210 | 183 | 921 |
| hygrocybe | 528 | 1200 | 370 | 931 |
| lactarius | 391 | 1280 | 280 | 1024 |
| russula | 400 | 1200 | 300 | 935 |
| suillus | 259 | 1280 | 184 | 961 |
print(" ", "width", "\t", "height")
print("min", df_files["width"].min(), "\t", df_files["height"].min())
print("max", df_files["width"].max(), "\t", df_files["height"].max())
      width 	 height
min   259 	 152
max   1280 	 1024
fig = px.density_heatmap(
df_files,
x="height",
y="width",
marginal_x="histogram",
marginal_y="histogram",
color_continuous_scale=[
(0, "white"),
(0.01, "lightgrey"),
(1, moonstone),
],
)
inner_ratio = df_files["width"].max() / df_files["height"].max()
marginal_ratio = 0.2
fig_width = 800
fig_height = fig_width / (inner_ratio + marginal_ratio)
fig.update_xaxes(range=[0, df_files["width"].max()])
fig.update_yaxes(range=[0, df_files["height"].max()])
fig.update_layout(
plot_bgcolor="white",
yaxis=dict(autorange="reversed"),
autosize=False,
width=fig_width,
height=fig_height,
title="heatmap of resolutions",
)
fig.show()
df_files[df_files["height"] > 1000]
| path | genus | height | width | |
|---|---|---|---|---|
| 1400 | dataset/Mushrooms_split/train/Amanita/382_HSyQ... | amanita | 1024 | 970 |
| 1502 | dataset/Mushrooms_split/train/Lactarius/414_JR... | lactarius | 1024 | 1112 |
| 1511 | dataset/Mushrooms_split/train/Lactarius/0279_D... | lactarius | 1024 | 768 |
| 1589 | dataset/Mushrooms_split/train/Lactarius/418_kI... | lactarius | 1024 | 1056 |
| 1623 | dataset/Mushrooms_split/train/Lactarius/0903_g... | lactarius | 1024 | 1273 |
| 2284 | dataset/Mushrooms_split/train/Lactarius/0129_o... | lactarius | 1024 | 682 |
| 2742 | dataset/Mushrooms_split/train/Cortinarius/070_... | cortinarius | 1024 | 940 |
df_files["w_group"] = df_files["width"] // 100
df_files["h_group"] = df_files["height"] // 100
Let's cluster images in buckets of 100x100 pixels
@run
@cached_chart()
def resolution_clusters():
heatmap_scaling = 0.6
f = plt.figure(figsize=(12 * heatmap_scaling, 10 * heatmap_scaling))
resolution_groups = (
df_files[["w_group", "h_group", "path"]]
.pivot_table(index="h_group", columns="w_group", aggfunc="count")
.droplevel(0, axis=1)
)
sns.heatmap(resolution_groups, cmap="Blues", annot=True, fmt=".0f")
plt.title("resolution clusters")
plt.xlabel("width x 100")
plt.ylabel("height x 100")
return f
Loading from cache [./cached/charts/resolution_clusters.png]
It seems that most images are around the SVGA range (~800x600). These are a good starting point: enough information to discern figures, but not unnecessarily large.
This project requires us to use transfer learning (by starting from a pre-trained model), so we might have to scale them somewhat, but at least they don't seem to suffer from major issues (not too large/small).
def valid_image_file(filename: str) -> bool:
# ImageFolder() requires the signature to be:
# Callable[[str], bool]
valid_last_bytes = {
".jpg": b"\xff\xd9",
".png": b"\x60\x82",
# ".gif": b"\x3b",
}
extension = filename[-4:]
if extension not in valid_last_bytes.keys():
raise ValueError(f"File extension is unknown. {extension = }")
with open(filename, mode="rb") as image_file:
file_content = image_file.read()
last_bytes = file_content[-2:]
return last_bytes == valid_last_bytes[extension]
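A quick sanity check of this trailer-byte heuristic (a complete JPEG stream ends in `FF D9`), using throwaway files. The compact function below mirrors the check above so the snippet is self-contained:

```python
import tempfile, os

# Mirrors the check above: a file is "valid" if it ends with its format's trailer bytes.
VALID_LAST_BYTES = {".jpg": b"\xff\xd9", ".png": b"\x60\x82"}

def valid_image_file(filename: str) -> bool:
    ext = filename[-4:]
    if ext not in VALID_LAST_BYTES:
        raise ValueError(f"File extension is unknown. {ext = }")
    with open(filename, "rb") as f:
        return f.read()[-2:] == VALID_LAST_BYTES[ext]

with tempfile.TemporaryDirectory() as tmp:
    ok = os.path.join(tmp, "ok.jpg")
    truncated = os.path.join(tmp, "cut.jpg")
    with open(ok, "wb") as f:
        f.write(b"\xff\xd8fake-jpeg-body\xff\xd9")  # well-terminated stream
    with open(truncated, "wb") as f:
        f.write(b"\xff\xd8fake-jpeg-body")          # missing FF D9 trailer
    results = (valid_image_file(ok), valid_image_file(truncated))
print(results)  # → (True, False)
```

This only detects truncation, not corruption elsewhere in the file, which is why `ImageFile.LOAD_TRUNCATED_IMAGES = True` is still a useful safety net.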
just_resize = transforms.Compose(
[
transforms.Resize((224, 224)), # resnet50 requires this!
transforms.ToTensor(),
]
)
raw_dataset = datasets.ImageFolder(
"dataset/Mushrooms_split/train/",
transform=just_resize,
is_valid_file=valid_image_file,
)
Some of the steps and transformations require us to normalize our images.
We want to calculate some metadata/metrics from our training split so we can apply some transformations/normalization that we hope will help our model.
def calculate_mean_std_for_split(dloader):
mean = 0.0
std = 0.0
nb_samples = 0.0
# Calculate the mean and std on the training set only
for data, _ in dloader:
batch_samples = data.size(0)
data = data.view(batch_samples, data.size(1), -1)
mean += data.mean(2).sum(0)
std += data.std(2).sum(0)
nb_samples += batch_samples
mean /= nb_samples
std /= nb_samples
print(f"Mean: {mean}")
print(f"Std: {std}")
return mean, std
train_mean, train_std = calculate_mean_std_for_split(DataLoader(raw_dataset))
Mean: tensor([0.3905, 0.3685, 0.2805]) Std: tensor([0.2296, 0.2095, 0.2032])
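A small aside on why averaging per-image means is safe here: since every image is resized to 224x224 before the statistics are computed, each sample contributes the same number of pixels, so the mean of per-image means equals the global mean (the averaged std is only an approximation of the global std). A tiny pure-Python check of the mean claim, with illustrative numbers:

```python
# Equal-sized samples: the mean of per-sample means equals the global mean.
samples = [[0.1, 0.3, 0.5, 0.7], [0.2, 0.2, 0.8, 0.4], [0.0, 0.6, 0.6, 0.2]]

per_sample_means = [sum(s) / len(s) for s in samples]
mean_of_means = sum(per_sample_means) / len(per_sample_means)

flat = [x for s in samples for x in s]
global_mean = sum(flat) / len(flat)

print(abs(mean_of_means - global_mean) < 1e-9)  # → True
```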
We have the mean/std for the raw images in the train split; we can use this to normalize our data for all 4 datasets.
Now that we know the mean/std values for our training images, we can create correctly configured data loaders for all of the splits.
Remember:
normalize = transforms.Compose(
[
transforms.Resize((224, 224)),
transforms.ToTensor(),
transforms.Normalize(mean=train_mean.tolist(), std=train_std.tolist()),
]
)
image_augmentation = transforms.Compose(
[
transforms.RandomResizedCrop(800),
transforms.Resize((400, 400)),
transforms.RandomHorizontalFlip(p=0.5),
transforms.RandomRotation(10),
transforms.RandomAffine(degrees=0, translate=(0.1, 0.1)),
transforms.RandomAffine(degrees=100),
transforms.GaussianBlur(3),
transforms.Resize((224, 224)),
transforms.ToTensor(),
transforms.Normalize(mean=train_mean.tolist(), std=train_std.tolist()),
]
)
train_dataset = datasets.ImageFolder(
"dataset/Mushrooms_split/train/",
transform=image_augmentation,
is_valid_file=valid_image_file,
)
hp_dataset = datasets.ImageFolder(
"dataset/Mushrooms_split/hp/",
transform=normalize,
is_valid_file=valid_image_file,
)
val_dataset = datasets.ImageFolder(
"dataset/Mushrooms_split/val/",
transform=normalize,
is_valid_file=valid_image_file,
)
test_dataset = datasets.ImageFolder(
"dataset/Mushrooms_split/test/",
transform=normalize,
is_valid_file=valid_image_file,
)
# using 15 workers; the value suggested by pytorch in its warning
data_load_settings = {"batch_size": 128, "num_workers": 15}
train_loader = DataLoader(
train_dataset, pin_memory=True, shuffle=True, **data_load_settings
)
hp_loader = DataLoader(hp_dataset, shuffle=False, **data_load_settings)
val_loader = DataLoader(val_dataset, shuffle=False, **data_load_settings)
test_loader = DataLoader(test_dataset, shuffle=False, **data_load_settings)
train_dataset
Dataset ImageFolder
Number of datapoints: 4363
Root location: dataset/Mushrooms_split/train/
StandardTransform
Transform: Compose(
RandomResizedCrop(size=(800, 800), scale=(0.08, 1.0), ratio=(0.75, 1.3333), interpolation=bilinear, antialias=warn)
Resize(size=(400, 400), interpolation=bilinear, max_size=None, antialias=warn)
RandomHorizontalFlip(p=0.5)
RandomRotation(degrees=[-10.0, 10.0], interpolation=nearest, expand=False, fill=0)
RandomAffine(degrees=[0.0, 0.0], translate=(0.1, 0.1))
RandomAffine(degrees=[-100.0, 100.0])
GaussianBlur(kernel_size=(3, 3), sigma=(0.1, 2.0))
Resize(size=(224, 224), interpolation=bilinear, max_size=None, antialias=warn)
ToTensor()
Normalize(mean=[0.3905445635318756, 0.3684823215007782, 0.2804551124572754], std=[0.22958891093730927, 0.20953883230686188, 0.2031669169664383])
)
Let's check that the images have been transformed as expected:
For Training:
images, labels = next(iter(train_loader))
images = images.numpy()[0:32]
f, ax = plt.subplots(4, 8, figsize=(15, 10))
for i, image in enumerate(images):
cax = ax[i // 8, i % 8]
image = np.transpose(image, (1, 2, 0))
# image = (image * train_std) + train_mean
image = np.clip(image, 0, 1)
cax.imshow(image)
plt.show()
For the HP (hyperparameter search) split:
images, labels = next(iter(hp_loader))
images = images.numpy()[0:32]
f, ax = plt.subplots(4, 8, figsize=(15, 10))
for i, image in enumerate(images):
cax = ax[i // 8, i % 8]
image = np.transpose(image, (1, 2, 0))
# image = (image * train_std) + train_mean
image = np.clip(image, 0, 1)
cax.imshow(image)
plt.show()
images, labels = next(iter(val_loader))
images = images.numpy()[0:32]
f, ax = plt.subplots(4, 8, figsize=(15, 10))
for i, image in enumerate(images):
cax = ax[i // 8, i % 8]
image = np.transpose(image, (1, 2, 0))
# image = (image * train_std) + train_mean
image = np.clip(image, 0, 1)
cax.imshow(image)
plt.show()
images, labels = next(iter(test_loader))
images = images.numpy()[0:32]
f, ax = plt.subplots(4, 8, figsize=(15, 10))
for i, image in enumerate(images):
cax = ax[i // 8, i % 8]
image = np.transpose(image, (1, 2, 0))
# image = (image * train_std) + train_mean
image = np.clip(image, 0, 1)
cax.imshow(image)
plt.show()
It seems that at least one of the images is not a valid example of a mushroom. The issue is that it would likely be expensive to scan them all and remove the pictures that, while valid image files, are not actually mushrooms.
Something we could do is flag any images that our model predicts with low confidence for manual inspection later, since manual inspection of 6000+ pictures is neither scalable nor desirable in this context.
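A sketch of that flag-for-review idea: convert the model's logits to softmax probabilities and queue any image whose top-class probability falls below a threshold. The 0.6 cutoff and the logits below are illustrative, not tuned values:

```python
import math

def softmax(logits):
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

def flag_low_confidence(batch_logits, threshold=0.6):
    """Return indices of samples whose top-class probability is below threshold."""
    flagged = []
    for i, logits in enumerate(batch_logits):
        if max(softmax(logits)) < threshold:
            flagged.append(i)
    return flagged

# Illustrative logits for 3 images over 3 classes: confident, unsure, confident.
batch = [[4.0, 0.1, 0.2], [1.0, 0.9, 1.1], [0.0, 5.0, 0.3]]
print(flag_low_confidence(batch))  # → [1]
```

Only the flagged minority would then need a human look, instead of all 6000+ images.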
All our dataloaders have been configured to normalize images based on the features in the training split ✔
def rebalance(loader, num_samples_per_class=1000):
targets = list(loader.dataset.targets)
class_count = torch.bincount(torch.tensor(targets))
class_weights = 1.0 / class_count.float()
print(class_weights)
weights = class_weights[torch.tensor(targets)]
print(len(weights), weights)
num_samples = num_samples_per_class * len(class_count)
rebalanced_sampler = WeightedRandomSampler(weights, num_samples)
rebalanced_dataloader = DataLoader(
dataset=loader.dataset,
batch_size=loader.batch_size,
sampler=rebalanced_sampler,
num_workers=loader.num_workers,
pin_memory=loader.pin_memory,
drop_last=loader.drop_last,
)
return rebalanced_dataloader
balanced_train_loader = rebalance(train_loader, num_samples_per_class=3000)
del train_loader
tensor([0.0043, 0.0020, 0.0014, 0.0018, 0.0042, 0.0049, 0.0010, 0.0013, 0.0050])
4363 tensor([0.0043, 0.0043, 0.0043, ..., 0.0050, 0.0050, 0.0050])
@run
@cached_chart()
def plot_classes_after_rebalance():
all_labels = []
for _, y in balanced_train_loader:
all_labels.append(y.flatten())
all_labels = torch.cat(all_labels)
class_counts = torch.bincount(all_labels)
print(class_counts)
return sns.countplot(x=all_labels.numpy(), color=moonstone)
Loading from cache [./cached/charts/plot_classes_after_rebalance.png]
Much better than the original distribution of classes
Let's create our Convolutional Neural Network to classify our data samples.
As required, we will use a pretrained model.
Let's try to visualize the structure of this pretrained network, as it is out of the box (without customizations):
(This diagram is generated by Keras, from a pre-built model of the same CNN. Good enough for a quick preview, but we will not be able to do the same with our own model later.)
if not os.path.exists("cached/resnet50_diagram.png"):
model = ResNet50(weights="imagenet")
plot_model(model, to_file="cached/resnet50_diagram.png", show_shapes=True, dpi=96)
Image.open("cached/resnet50_diagram.png")
The important bit is at the very end, in the last 4 or 5 blocks: we will remove the 1000-neuron output layer and replace it with our own 9-neuron output.
class MushroomClassifier(pl.LightningModule):
def __init__(self, mushroom_classes=9, lr=0.001, betas=(0.9, 0.999), eps=1e-8):
super().__init__()
self.model = models.resnet50(weights=models.ResNet50_Weights.DEFAULT)
for param in self.model.parameters():
param.requires_grad = False
cnn_codes = self.model.fc.in_features
self.model.fc = nn.Linear(cnn_codes, mushroom_classes)
self.m_acc = MulticlassAccuracy(num_classes=mushroom_classes)
self.m_prec = MulticlassPrecision(num_classes=mushroom_classes)
self.m_recall = MulticlassRecall(num_classes=mushroom_classes)
self.m_f1 = MulticlassF1Score(num_classes=mushroom_classes)
self.m_cm = MulticlassConfusionMatrix(num_classes=mushroom_classes)
self.lr = lr
self.betas = betas
self.eps = eps
def forward(self, x):
return self.model(x)
def training_step(self, batch, batch_idx):
x, y = batch
y_pred = self(x)
loss = nn.CrossEntropyLoss()(y_pred, y)
self.log("train_loss", loss)
self.log("train_accuracy", self.m_acc(y_pred, y), on_step=True, on_epoch=True)
self.log("train_precision", self.m_prec(y_pred, y), on_step=True, on_epoch=True)
self.log("train_recall", self.m_recall(y_pred, y), on_step=True, on_epoch=True)
self.log("train_f1", self.m_f1(y_pred, y), on_step=True, on_epoch=True)
return loss
def validation_step(self, batch, batch_idx):
x, y = batch
y_pred = self(x)
loss = nn.CrossEntropyLoss()(y_pred, y)
self.log("val_loss", loss)
self.log("val_accuracy", self.m_acc(y_pred, y), on_step=True, on_epoch=True)
self.log("val_precision", self.m_prec(y_pred, y), on_step=True, on_epoch=True)
self.log("val_recall", self.m_recall(y_pred, y), on_step=True, on_epoch=True)
self.log("val_f1", self.m_f1(y_pred, y), on_step=True, on_epoch=True)
return loss
def predict_step(self, batch, batch_idx):
x, y = batch
predicted = self.forward(x)
predicted = torch.argmax(predicted, 1)
return predicted
def test_step(self, batch, batch_idx):
x, y = batch
y_pred = self(x)
loss = nn.CrossEntropyLoss()(y_pred, y)
self.log("test_loss", loss)
self.log("test_accuracy", self.m_acc(y_pred, y), on_step=True, on_epoch=True)
self.log("test_precision", self.m_prec(y_pred, y), on_step=True, on_epoch=True)
self.log("test_recall", self.m_recall(y_pred, y), on_step=True, on_epoch=True)
self.log("test_f1", self.m_f1(y_pred, y), on_step=True, on_epoch=True)
return loss
def configure_optimizers(self):
# RMSprop(self.parameters(), lr=self.lr)
optimizer = torch.optim.Adam(
self.parameters(),
lr=self.lr,
betas=self.betas,
eps=self.eps,
)
return optimizer
A few things to notice: the pretrained backbone is frozen (requires_grad = False), so only the new 9-neuron head is trainable, and the Adam hyperparameters (lr, betas, eps) are exposed so we can tune them later.
In addition to creating this model, we also want our PyTorch Lightning trainer to work optimally, so we set up early-stopping and checkpointing callbacks:
def new_early_stopping_callback(
    metric_to_monitor="val_loss",
    min_change_to_consider_an_improvement=0.00,
    stop_after_x_epochs_without_improvement=3,
):
    return EarlyStopping(
        monitor=metric_to_monitor,
        min_delta=min_change_to_consider_an_improvement,
        patience=stop_after_x_epochs_without_improvement,
        verbose=False,
        mode="min",
    )


def new_checkpoint_callback(
    metric_to_monitor="val_loss",
    mode="min",
    filename="checkpoint_resnet50",
    save_top_k=10,
):
    return ModelCheckpoint(
        dirpath="models/training/",
        save_top_k=save_top_k,
        verbose=True,
        auto_insert_metric_name=True,
        monitor=metric_to_monitor,
        mode=mode,
        filename=filename,
    )
mushroom_classifier_model = MushroomClassifier(mushroom_classes=9, lr=0.0001)
if run_entire_notebook():
    callbacks = [new_checkpoint_callback(), new_early_stopping_callback()]
    trainer = Trainer(
        max_epochs=2,
        callbacks=callbacks,
    )
    trainer.fit(
        mushroom_classifier_model,
        train_dataloaders=balanced_train_loader,
        val_dataloaders=val_loader,
    )
skipping optional operation
Let's take a look at a few interesting things:
18.4 K Trainable params
23.5 M Non-trainable params
23.5 M Total params
94.106 Total estimated model params size (MB)
These numbers match our expectations.
18.4K makes sense with some simple math. For a fully connected layer, the parameter count is:
$\text{params} = (\text{inputNeurons} + 1_{\text{bias}}) \times \text{outputNeurons}$
So:
$18.4\text{k} \approx (\text{Inp} + 1) \times 9$
We therefore expect the number of neurons in the previous layer to be around:
${{18.4\text{k}} \over {9}} - 1 \approx \text{Inp} \approx 2043.44$
We can already see that this is very close to a power of two (2048), which is likely the size of the previous layer!
Let's extract the number of neurons in the layer that feeds into our 9 outputs to verify this:

This is the "out of the box model", this is why it shows 1000 output neurons.
To get the number of neurons before our custom last layer, we can query it directly:
mushroom_classifier_model.model.fc.in_features
2048
Or we can use torchsummary to inspect the entire CNN:
from torchsummary import summary
summary(mushroom_classifier_model.cuda(), input_size=(3, 224, 224))
----------------------------------------------------------------
Layer (type) Output Shape Param #
================================================================
Conv2d-1 [-1, 64, 112, 112] 9,408
BatchNorm2d-2 [-1, 64, 112, 112] 128
ReLU-3 [-1, 64, 112, 112] 0
MaxPool2d-4 [-1, 64, 56, 56] 0
Conv2d-5 [-1, 64, 56, 56] 4,096
BatchNorm2d-6 [-1, 64, 56, 56] 128
ReLU-7 [-1, 64, 56, 56] 0
Conv2d-8 [-1, 64, 56, 56] 36,864
BatchNorm2d-9 [-1, 64, 56, 56] 128
ReLU-10 [-1, 64, 56, 56] 0
Conv2d-11 [-1, 256, 56, 56] 16,384
BatchNorm2d-12 [-1, 256, 56, 56] 512
Conv2d-13 [-1, 256, 56, 56] 16,384
BatchNorm2d-14 [-1, 256, 56, 56] 512
ReLU-15 [-1, 256, 56, 56] 0
Bottleneck-16 [-1, 256, 56, 56] 0
Conv2d-17 [-1, 64, 56, 56] 16,384
BatchNorm2d-18 [-1, 64, 56, 56] 128
ReLU-19 [-1, 64, 56, 56] 0
Conv2d-20 [-1, 64, 56, 56] 36,864
BatchNorm2d-21 [-1, 64, 56, 56] 128
ReLU-22 [-1, 64, 56, 56] 0
Conv2d-23 [-1, 256, 56, 56] 16,384
BatchNorm2d-24 [-1, 256, 56, 56] 512
ReLU-25 [-1, 256, 56, 56] 0
Bottleneck-26 [-1, 256, 56, 56] 0
Conv2d-27 [-1, 64, 56, 56] 16,384
BatchNorm2d-28 [-1, 64, 56, 56] 128
ReLU-29 [-1, 64, 56, 56] 0
Conv2d-30 [-1, 64, 56, 56] 36,864
BatchNorm2d-31 [-1, 64, 56, 56] 128
ReLU-32 [-1, 64, 56, 56] 0
Conv2d-33 [-1, 256, 56, 56] 16,384
BatchNorm2d-34 [-1, 256, 56, 56] 512
ReLU-35 [-1, 256, 56, 56] 0
Bottleneck-36 [-1, 256, 56, 56] 0
Conv2d-37 [-1, 128, 56, 56] 32,768
BatchNorm2d-38 [-1, 128, 56, 56] 256
ReLU-39 [-1, 128, 56, 56] 0
Conv2d-40 [-1, 128, 28, 28] 147,456
BatchNorm2d-41 [-1, 128, 28, 28] 256
ReLU-42 [-1, 128, 28, 28] 0
Conv2d-43 [-1, 512, 28, 28] 65,536
BatchNorm2d-44 [-1, 512, 28, 28] 1,024
Conv2d-45 [-1, 512, 28, 28] 131,072
BatchNorm2d-46 [-1, 512, 28, 28] 1,024
ReLU-47 [-1, 512, 28, 28] 0
Bottleneck-48 [-1, 512, 28, 28] 0
Conv2d-49 [-1, 128, 28, 28] 65,536
BatchNorm2d-50 [-1, 128, 28, 28] 256
ReLU-51 [-1, 128, 28, 28] 0
Conv2d-52 [-1, 128, 28, 28] 147,456
BatchNorm2d-53 [-1, 128, 28, 28] 256
ReLU-54 [-1, 128, 28, 28] 0
Conv2d-55 [-1, 512, 28, 28] 65,536
BatchNorm2d-56 [-1, 512, 28, 28] 1,024
ReLU-57 [-1, 512, 28, 28] 0
Bottleneck-58 [-1, 512, 28, 28] 0
Conv2d-59 [-1, 128, 28, 28] 65,536
BatchNorm2d-60 [-1, 128, 28, 28] 256
ReLU-61 [-1, 128, 28, 28] 0
Conv2d-62 [-1, 128, 28, 28] 147,456
BatchNorm2d-63 [-1, 128, 28, 28] 256
ReLU-64 [-1, 128, 28, 28] 0
Conv2d-65 [-1, 512, 28, 28] 65,536
BatchNorm2d-66 [-1, 512, 28, 28] 1,024
ReLU-67 [-1, 512, 28, 28] 0
Bottleneck-68 [-1, 512, 28, 28] 0
Conv2d-69 [-1, 128, 28, 28] 65,536
BatchNorm2d-70 [-1, 128, 28, 28] 256
ReLU-71 [-1, 128, 28, 28] 0
Conv2d-72 [-1, 128, 28, 28] 147,456
BatchNorm2d-73 [-1, 128, 28, 28] 256
ReLU-74 [-1, 128, 28, 28] 0
Conv2d-75 [-1, 512, 28, 28] 65,536
BatchNorm2d-76 [-1, 512, 28, 28] 1,024
ReLU-77 [-1, 512, 28, 28] 0
Bottleneck-78 [-1, 512, 28, 28] 0
Conv2d-79 [-1, 256, 28, 28] 131,072
BatchNorm2d-80 [-1, 256, 28, 28] 512
ReLU-81 [-1, 256, 28, 28] 0
Conv2d-82 [-1, 256, 14, 14] 589,824
BatchNorm2d-83 [-1, 256, 14, 14] 512
ReLU-84 [-1, 256, 14, 14] 0
Conv2d-85 [-1, 1024, 14, 14] 262,144
BatchNorm2d-86 [-1, 1024, 14, 14] 2,048
Conv2d-87 [-1, 1024, 14, 14] 524,288
BatchNorm2d-88 [-1, 1024, 14, 14] 2,048
ReLU-89 [-1, 1024, 14, 14] 0
Bottleneck-90 [-1, 1024, 14, 14] 0
Conv2d-91 [-1, 256, 14, 14] 262,144
BatchNorm2d-92 [-1, 256, 14, 14] 512
ReLU-93 [-1, 256, 14, 14] 0
Conv2d-94 [-1, 256, 14, 14] 589,824
BatchNorm2d-95 [-1, 256, 14, 14] 512
ReLU-96 [-1, 256, 14, 14] 0
Conv2d-97 [-1, 1024, 14, 14] 262,144
BatchNorm2d-98 [-1, 1024, 14, 14] 2,048
ReLU-99 [-1, 1024, 14, 14] 0
Bottleneck-100 [-1, 1024, 14, 14] 0
Conv2d-101 [-1, 256, 14, 14] 262,144
BatchNorm2d-102 [-1, 256, 14, 14] 512
ReLU-103 [-1, 256, 14, 14] 0
Conv2d-104 [-1, 256, 14, 14] 589,824
BatchNorm2d-105 [-1, 256, 14, 14] 512
ReLU-106 [-1, 256, 14, 14] 0
Conv2d-107 [-1, 1024, 14, 14] 262,144
BatchNorm2d-108 [-1, 1024, 14, 14] 2,048
ReLU-109 [-1, 1024, 14, 14] 0
Bottleneck-110 [-1, 1024, 14, 14] 0
Conv2d-111 [-1, 256, 14, 14] 262,144
BatchNorm2d-112 [-1, 256, 14, 14] 512
ReLU-113 [-1, 256, 14, 14] 0
Conv2d-114 [-1, 256, 14, 14] 589,824
BatchNorm2d-115 [-1, 256, 14, 14] 512
ReLU-116 [-1, 256, 14, 14] 0
Conv2d-117 [-1, 1024, 14, 14] 262,144
BatchNorm2d-118 [-1, 1024, 14, 14] 2,048
ReLU-119 [-1, 1024, 14, 14] 0
Bottleneck-120 [-1, 1024, 14, 14] 0
Conv2d-121 [-1, 256, 14, 14] 262,144
BatchNorm2d-122 [-1, 256, 14, 14] 512
ReLU-123 [-1, 256, 14, 14] 0
Conv2d-124 [-1, 256, 14, 14] 589,824
BatchNorm2d-125 [-1, 256, 14, 14] 512
ReLU-126 [-1, 256, 14, 14] 0
Conv2d-127 [-1, 1024, 14, 14] 262,144
BatchNorm2d-128 [-1, 1024, 14, 14] 2,048
ReLU-129 [-1, 1024, 14, 14] 0
Bottleneck-130 [-1, 1024, 14, 14] 0
Conv2d-131 [-1, 256, 14, 14] 262,144
BatchNorm2d-132 [-1, 256, 14, 14] 512
ReLU-133 [-1, 256, 14, 14] 0
Conv2d-134 [-1, 256, 14, 14] 589,824
BatchNorm2d-135 [-1, 256, 14, 14] 512
ReLU-136 [-1, 256, 14, 14] 0
Conv2d-137 [-1, 1024, 14, 14] 262,144
BatchNorm2d-138 [-1, 1024, 14, 14] 2,048
ReLU-139 [-1, 1024, 14, 14] 0
Bottleneck-140 [-1, 1024, 14, 14] 0
Conv2d-141 [-1, 512, 14, 14] 524,288
BatchNorm2d-142 [-1, 512, 14, 14] 1,024
ReLU-143 [-1, 512, 14, 14] 0
Conv2d-144 [-1, 512, 7, 7] 2,359,296
BatchNorm2d-145 [-1, 512, 7, 7] 1,024
ReLU-146 [-1, 512, 7, 7] 0
Conv2d-147 [-1, 2048, 7, 7] 1,048,576
BatchNorm2d-148 [-1, 2048, 7, 7] 4,096
Conv2d-149 [-1, 2048, 7, 7] 2,097,152
BatchNorm2d-150 [-1, 2048, 7, 7] 4,096
ReLU-151 [-1, 2048, 7, 7] 0
Bottleneck-152 [-1, 2048, 7, 7] 0
Conv2d-153 [-1, 512, 7, 7] 1,048,576
BatchNorm2d-154 [-1, 512, 7, 7] 1,024
ReLU-155 [-1, 512, 7, 7] 0
Conv2d-156 [-1, 512, 7, 7] 2,359,296
BatchNorm2d-157 [-1, 512, 7, 7] 1,024
ReLU-158 [-1, 512, 7, 7] 0
Conv2d-159 [-1, 2048, 7, 7] 1,048,576
BatchNorm2d-160 [-1, 2048, 7, 7] 4,096
ReLU-161 [-1, 2048, 7, 7] 0
Bottleneck-162 [-1, 2048, 7, 7] 0
Conv2d-163 [-1, 512, 7, 7] 1,048,576
BatchNorm2d-164 [-1, 512, 7, 7] 1,024
ReLU-165 [-1, 512, 7, 7] 0
Conv2d-166 [-1, 512, 7, 7] 2,359,296
BatchNorm2d-167 [-1, 512, 7, 7] 1,024
ReLU-168 [-1, 512, 7, 7] 0
Conv2d-169 [-1, 2048, 7, 7] 1,048,576
BatchNorm2d-170 [-1, 2048, 7, 7] 4,096
ReLU-171 [-1, 2048, 7, 7] 0
Bottleneck-172 [-1, 2048, 7, 7] 0
AdaptiveAvgPool2d-173 [-1, 2048, 1, 1] 0
Linear-174 [-1, 9] 18,441
ResNet-175 [-1, 9] 0
================================================================
Total params: 23,526,473
Trainable params: 18,441
Non-trainable params: 23,508,032
----------------------------------------------------------------
Input size (MB): 0.57
Forward/backward pass size (MB): 286.55
Params size (MB): 89.75
Estimated Total Size (MB): 376.87
----------------------------------------------------------------
The important bit is at the very end:
ReLU-171 [-1, 2048, 7, 7] 0
Bottleneck-172 [-1, 2048, 7, 7] 0
AdaptiveAvgPool2d-173 [-1, 2048, 1, 1] 0 <<< these two are the
Linear-174 [-1, 9] 18,441 <<< important bits
ResNet-175 [-1, 9] 0
================================================================
Sure enough, $(2048 + 1) \times 9 = 18441 \approx 18.4\text{k}$, matching what we saw before:
result = (2048 + 1) * 9
print(result)
util.check(result == 18441)
assert result == 18441, "# of parameters does not match our calculations"
18441 ✅
Let's use Optuna to perform some hyperparameter tuning.
For CNNs, one of the most critical hyperparameters to tune is the learning rate.
Since we want to see how things work under the hood, let's also tune a couple of extra parameters: the Adam decay coefficients beta1 and beta2.
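Why sample the learning rate log-uniformly rather than uniformly? On a linear scale, values below 1e-3 occupy only ~1% of the [1e-5, 1e-1] range, so uniform sampling would almost never try small learning rates. A quick illustration, independent of Optuna:

```python
import math
import random

random.seed(0)
low, high = 1e-5, 1e-1
n = 10_000

# uniform on a linear scale: small magnitudes are almost never drawn
linear = [random.uniform(low, high) for _ in range(n)]
# uniform on a log scale: every decade is equally likely
log_scale = [math.exp(random.uniform(math.log(low), math.log(high))) for _ in range(n)]

frac_small_linear = sum(x < 1e-3 for x in linear) / n
frac_small_log = sum(x < 1e-3 for x in log_scale) / n
print(f"{frac_small_linear:.3f}")  # ≈ 0.010 (1% of the linear range)
print(f"{frac_small_log:.3f}")     # ≈ 0.5   (2 of the 4 decades)
```

This is exactly what sampling the learning rate on a log scale gives us in the objective below.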
def optimize_hyperparams(study_name: str):
    def objective(trial):
        print("providing new parameters from optuna")
        # suggest_loguniform/suggest_uniform are deprecated; suggest_float covers both
        lr = trial.suggest_float("lr", 1e-5, 1e-1, log=True)
        beta1 = trial.suggest_float("beta1", 0.8, 1.0)
        beta2 = trial.suggest_float("beta2", 0.9, 1.0)
        newly_created_model = MushroomClassifier(
            mushroom_classes=9, lr=lr, betas=(beta1, beta2)
        )
        trainer = Trainer(
            max_epochs=60,
            callbacks=[
                new_checkpoint_callback(),
                new_early_stopping_callback(stop_after_x_epochs_without_improvement=6),
                PyTorchLightningPruningCallback(trial, monitor="val_loss"),
            ],
        )
        trainer.fit(
            newly_created_model,
            train_dataloaders=balanced_train_loader,  # hp_loader
            val_dataloaders=val_loader,
        )
        return trainer.callback_metrics["val_loss"].item()

    optuna_utils.create_optuna_study(study_name, allow_resume=True, direction="minimize")
    optuna_utils.get_study(study_name).optimize(objective, n_trials=100)
if run_entire_notebook():
    optimize_hyperparams(study_name="mushroom_training_val")
    # optimize_hyperparams(study_name="mushroom_hp_val")
skipping optional operation
@run
@cached_chart()
def loss_best_trial():
    studies = ["mushroom_training_val", "mushroom_hp_val"]
    f, ax = plt.subplots(1, 2, figsize=(12, 5))
    f.suptitle("comparison of hp tuning depending on dataset size")
    for i, study_name in enumerate(studies):
        cax = ax[i]
        cax.set_title(f"loss for best trial {study_name}")
        cax.set_xlabel("epoch")
        cax.set_ylabel("loss")
        study = optuna_utils.get_study(study_name)
        l = sns.lineplot(study.best_trial.intermediate_values, ax=cax, color=moonstone)
        lr = study.best_params["lr"]
        beta1 = study.best_params["beta1"]
        beta2 = study.best_params["beta2"]
        props = dict(boxstyle="round", facecolor="grey", alpha=0.15)  # bbox features
        cax.text(
            1.03,
            0.98,
            f"{lr = :.4f}\n{beta1 = :.4f}\n{beta2 = :.4f}",
            transform=cax.transAxes,
            fontsize=12,
            verticalalignment="top",
            bbox=props,
        )
        cax.set_ylim(0, 2)
        cax.set_xlim(0, 60)
    plt.tight_layout()
    return l
Loading from cache [./cached/charts/loss_best_trial.png]
We can see that the larger dataset helps our model learn better and faster during hyperparameter tuning.
The decay parameters beta1/beta2 also look much healthier in the first (larger-dataset) study, with beta2 considerably closer to 1 than beta1, instead of the near-identical values found in the second study.
Since the better-tuned model benefits from the larger training dataset, we will only compare performance for that study and ignore the remaining Optuna studies.
def loss_all_trials(study_name: str, ids=None):
    plt.xlabel("epoch")
    plt.ylabel("loss")
    study = optuna_utils.get_study(study_name)
    best_trial = study.best_trial.number
    for trial in study.trials:
        col = moonstone if trial.number == best_trial else "grey"
        lw = 3 if trial.number == best_trial else 0.3
        label = "Best Trial" if trial.number == best_trial else None
        if not ids:
            plt.title(f"loss for all trials [{study_name}]")
            sns.lineplot(trial.intermediate_values, color=col, linewidth=lw, label=label)
        elif trial.number in ids:
            plt.title(f"loss for trials [{study_name}] {ids}")
            sns.lineplot(
                trial.intermediate_values,
                label=f"trial {trial.number}",
                color=col,
                linewidth=lw,
            )
    plt.legend()
    plt.ylim(0, 2.1)
    return plt.gca()
@run
@cached_chart()
def loss_all_trials_using_training_split():
    return loss_all_trials("mushroom_training_val")
Loading from cache [./cached/charts/loss_all_trials_using_training_split.png]
optuna_study = optuna_utils.get_study(study_name="mushroom_training_val")
studies_df = optuna_study.trials_dataframe()
studies_df = studies_df[studies_df["state"] == "COMPLETE"]
longest = studies_df.sort_values("duration", ascending=False)[:5].number
studies_of_interest = set(longest) | set([optuna_study.best_trial.number])
studies_df.loc[list(studies_of_interest)]
| number | value | datetime_start | datetime_complete | duration | params_beta1 | params_beta2 | params_lr | state | |
|---|---|---|---|---|---|---|---|---|---|
| 1 | 1 | 0.787550 | 2023-12-10 03:08:46.818755 | 2023-12-10 03:55:18.139670 | 0 days 00:46:31.320915 | 0.873872 | 0.967798 | 0.000377 | COMPLETE |
| 2 | 2 | 1.119920 | 2023-12-10 03:55:18.157443 | 2023-12-10 05:03:13.707603 | 0 days 01:07:55.550160 | 0.883608 | 0.982735 | 0.000024 | COMPLETE |
| 4 | 4 | 0.769809 | 2023-12-10 05:21:20.399222 | 2023-12-10 06:23:34.361549 | 0 days 01:02:13.962327 | 0.841713 | 0.966169 | 0.000302 | COMPLETE |
| 6 | 6 | 0.779809 | 2023-12-10 06:41:43.628434 | 2023-12-10 07:13:26.086821 | 0 days 00:31:42.458387 | 0.853146 | 0.910811 | 0.000900 | COMPLETE |
| 24 | 24 | 0.815131 | 2023-12-10 08:34:20.263166 | 2023-12-10 09:03:57.506501 | 0 days 00:29:37.243335 | 0.833433 | 0.955447 | 0.001120 | COMPLETE |
loss_all_trials("mushroom_training_val", studies_of_interest)
<AxesSubplot: title={'center': 'loss for trials [mushroom_training_val] {1, 2, 4, 6, 24}'}, xlabel='epoch', ylabel='loss'>
A second attempt used a smaller dataset split for HP tuning, in the hope of finding good values faster.
However, this resulted in noticeably worse performance: none of the trials come anywhere close to the ~0.7-0.8 loss on unseen data.
loss_all_trials("mushroom_hp_val")
<AxesSubplot: title={'center': 'loss for all trials [mushroom_hp_val]'}, xlabel='epoch', ylabel='loss'>
study = optuna_utils.get_study(study_name="mushroom_training_val")
print("loss\t", study.best_value)
print("params\t", study.best_params)
loss 0.7698094248771667
params {'lr': 0.0003022396228963802, 'beta1': 0.8417125470622357, 'beta2': 0.966169418528305}
best_model = MushroomClassifier(
    mushroom_classes=9,
    lr=study.best_params["lr"],
    betas=(study.best_params["beta1"], study.best_params["beta2"]),
)
trainer = Trainer(
    max_epochs=60,
    callbacks=[
        new_checkpoint_callback(),
        new_early_stopping_callback(stop_after_x_epochs_without_improvement=6),
    ],
)
trainer.fit(
    best_model,
    train_dataloaders=balanced_train_loader,  # hp_loader
    val_dataloaders=val_loader,
)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
You are using a CUDA device ('NVIDIA GeForce RTX 3060') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
/home/edu/anaconda3/envs/py39_lab4/lib/python3.9/site-packages/pytorch_lightning/callbacks/model_checkpoint.py:634: UserWarning:
Checkpoint directory /home/edu/turing/projects/sprint13-mushrooms/project/models/training exists and is not empty.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
| Name | Type | Params
-------------------------------------------------------
0 | model | ResNet | 23.5 M
1 | m_acc | MulticlassAccuracy | 0
2 | m_prec | MulticlassPrecision | 0
3 | m_recall | MulticlassRecall | 0
4 | m_f1 | MulticlassF1Score | 0
5 | m_cm | MulticlassConfusionMatrix | 0
-------------------------------------------------------
18.4 K Trainable params
23.5 M Non-trainable params
23.5 M Total params
94.106 Total estimated model params size (MB)
Sanity Checking: | | 0/? [00:00<?, ?it/s]
Training: | | 0/? [00:00<?, ?it/s]
Validation: | | 0/? [00:00<?, ?it/s]
Epoch 0, global step 211: 'val_loss' reached 1.59601 (best 1.59601), saving model to '/home/edu/turing/projects/sprint13-mushrooms/project/models/training/checkpoint_resnet50-v112.ckpt' as top 10
Validation: | | 0/? [00:00<?, ?it/s]
Epoch 1, global step 422: 'val_loss' reached 1.36843 (best 1.36843), saving model to '/home/edu/turing/projects/sprint13-mushrooms/project/models/training/checkpoint_resnet50-v113.ckpt' as top 10
Validation: | | 0/? [00:00<?, ?it/s]
Epoch 2, global step 633: 'val_loss' reached 1.20773 (best 1.20773), saving model to '/home/edu/turing/projects/sprint13-mushrooms/project/models/training/checkpoint_resnet50-v114.ckpt' as top 10
Validation: | | 0/? [00:00<?, ?it/s]
Epoch 3, global step 844: 'val_loss' reached 1.14454 (best 1.14454), saving model to '/home/edu/turing/projects/sprint13-mushrooms/project/models/training/checkpoint_resnet50-v115.ckpt' as top 10
Validation: | | 0/? [00:00<?, ?it/s]
Epoch 4, global step 1055: 'val_loss' reached 1.10301 (best 1.10301), saving model to '/home/edu/turing/projects/sprint13-mushrooms/project/models/training/checkpoint_resnet50-v116.ckpt' as top 10
Validation: | | 0/? [00:00<?, ?it/s]
Epoch 5, global step 1266: 'val_loss' reached 1.05823 (best 1.05823), saving model to '/home/edu/turing/projects/sprint13-mushrooms/project/models/training/checkpoint_resnet50-v117.ckpt' as top 10
Validation: | | 0/? [00:00<?, ?it/s]
Epoch 6, global step 1477: 'val_loss' reached 1.03442 (best 1.03442), saving model to '/home/edu/turing/projects/sprint13-mushrooms/project/models/training/checkpoint_resnet50-v118.ckpt' as top 10
Validation: | | 0/? [00:00<?, ?it/s]
Epoch 7, global step 1688: 'val_loss' reached 0.99475 (best 0.99475), saving model to '/home/edu/turing/projects/sprint13-mushrooms/project/models/training/checkpoint_resnet50-v119.ckpt' as top 10
Validation: | | 0/? [00:00<?, ?it/s]
Epoch 8, global step 1899: 'val_loss' reached 0.95910 (best 0.95910), saving model to '/home/edu/turing/projects/sprint13-mushrooms/project/models/training/checkpoint_resnet50-v120.ckpt' as top 10
Validation: | | 0/? [00:00<?, ?it/s]
Epoch 9, global step 2110: 'val_loss' reached 0.94587 (best 0.94587), saving model to '/home/edu/turing/projects/sprint13-mushrooms/project/models/training/checkpoint_resnet50-v121.ckpt' as top 10
Validation: | | 0/? [00:00<?, ?it/s]
Epoch 10, global step 2321: 'val_loss' reached 0.93859 (best 0.93859), saving model to '/home/edu/turing/projects/sprint13-mushrooms/project/models/training/checkpoint_resnet50-v112.ckpt' as top 10
Validation: | | 0/? [00:00<?, ?it/s]
Epoch 11, global step 2532: 'val_loss' reached 0.93307 (best 0.93307), saving model to '/home/edu/turing/projects/sprint13-mushrooms/project/models/training/checkpoint_resnet50-v113.ckpt' as top 10
Validation: | | 0/? [00:00<?, ?it/s]
Epoch 12, global step 2743: 'val_loss' reached 0.90248 (best 0.90248), saving model to '/home/edu/turing/projects/sprint13-mushrooms/project/models/training/checkpoint_resnet50-v114.ckpt' as top 10
Validation: | | 0/? [00:00<?, ?it/s]
Epoch 13, global step 2954: 'val_loss' reached 0.88191 (best 0.88191), saving model to '/home/edu/turing/projects/sprint13-mushrooms/project/models/training/checkpoint_resnet50-v115.ckpt' as top 10
Validation: | | 0/? [00:00<?, ?it/s]
Epoch 14, global step 3165: 'val_loss' reached 0.86941 (best 0.86941), saving model to '/home/edu/turing/projects/sprint13-mushrooms/project/models/training/checkpoint_resnet50-v116.ckpt' as top 10
Validation: | | 0/? [00:00<?, ?it/s]
Epoch 15, global step 3376: 'val_loss' reached 0.87006 (best 0.86941), saving model to '/home/edu/turing/projects/sprint13-mushrooms/project/models/training/checkpoint_resnet50-v117.ckpt' as top 10
Validation: | | 0/? [00:00<?, ?it/s]
Epoch 16, global step 3587: 'val_loss' reached 0.85080 (best 0.85080), saving model to '/home/edu/turing/projects/sprint13-mushrooms/project/models/training/checkpoint_resnet50-v118.ckpt' as top 10
Validation: | | 0/? [00:00<?, ?it/s]
Epoch 17, global step 3798: 'val_loss' reached 0.84879 (best 0.84879), saving model to '/home/edu/turing/projects/sprint13-mushrooms/project/models/training/checkpoint_resnet50-v119.ckpt' as top 10
Validation: | | 0/? [00:00<?, ?it/s]
Epoch 18, global step 4009: 'val_loss' reached 0.86033 (best 0.84879), saving model to '/home/edu/turing/projects/sprint13-mushrooms/project/models/training/checkpoint_resnet50-v120.ckpt' as top 10
Validation: | | 0/? [00:00<?, ?it/s]
Epoch 19, global step 4220: 'val_loss' reached 0.83894 (best 0.83894), saving model to '/home/edu/turing/projects/sprint13-mushrooms/project/models/training/checkpoint_resnet50-v121.ckpt' as top 10
Validation: | | 0/? [00:00<?, ?it/s]
Epoch 20, global step 4431: 'val_loss' reached 0.83431 (best 0.83431), saving model to '/home/edu/turing/projects/sprint13-mushrooms/project/models/training/checkpoint_resnet50-v112.ckpt' as top 10
Validation: | | 0/? [00:00<?, ?it/s]
Epoch 21, global step 4642: 'val_loss' reached 0.84595 (best 0.83431), saving model to '/home/edu/turing/projects/sprint13-mushrooms/project/models/training/checkpoint_resnet50-v113.ckpt' as top 10
Validation: | | 0/? [00:00<?, ?it/s]
Epoch 22, global step 4853: 'val_loss' reached 0.83944 (best 0.83431), saving model to '/home/edu/turing/projects/sprint13-mushrooms/project/models/training/checkpoint_resnet50-v114.ckpt' as top 10
Validation: | | 0/? [00:00<?, ?it/s]
Epoch 23, global step 5064: 'val_loss' reached 0.83701 (best 0.83431), saving model to '/home/edu/turing/projects/sprint13-mushrooms/project/models/training/checkpoint_resnet50-v115.ckpt' as top 10
Validation: | | 0/? [00:00<?, ?it/s]
Epoch 24, global step 5275: 'val_loss' reached 0.83139 (best 0.83139), saving model to '/home/edu/turing/projects/sprint13-mushrooms/project/models/training/checkpoint_resnet50-v117.ckpt' as top 10
Validation: | | 0/? [00:00<?, ?it/s]
Epoch 25, global step 5486: 'val_loss' reached 0.82635 (best 0.82635), saving model to '/home/edu/turing/projects/sprint13-mushrooms/project/models/training/checkpoint_resnet50-v116.ckpt' as top 10
Validation: | | 0/? [00:00<?, ?it/s]
Epoch 26, global step 5697: 'val_loss' reached 0.81184 (best 0.81184), saving model to '/home/edu/turing/projects/sprint13-mushrooms/project/models/training/checkpoint_resnet50-v120.ckpt' as top 10
Validation: | | 0/? [00:00<?, ?it/s]
Epoch 27, global step 5908: 'val_loss' reached 0.80840 (best 0.80840), saving model to '/home/edu/turing/projects/sprint13-mushrooms/project/models/training/checkpoint_resnet50-v118.ckpt' as top 10
Validation: | | 0/? [00:00<?, ?it/s]
Epoch 28, global step 6119: 'val_loss' reached 0.80679 (best 0.80679), saving model to '/home/edu/turing/projects/sprint13-mushrooms/project/models/training/checkpoint_resnet50-v119.ckpt' as top 10
Validation: | | 0/? [00:00<?, ?it/s]
Epoch 29, global step 6330: 'val_loss' reached 0.81115 (best 0.80679), saving model to '/home/edu/turing/projects/sprint13-mushrooms/project/models/training/checkpoint_resnet50-v113.ckpt' as top 10
Validation: | | 0/? [00:00<?, ?it/s]
Epoch 30, global step 6541: 'val_loss' reached 0.81302 (best 0.80679), saving model to '/home/edu/turing/projects/sprint13-mushrooms/project/models/training/checkpoint_resnet50-v114.ckpt' as top 10
Validation: | | 0/? [00:00<?, ?it/s]
Epoch 31, global step 6752: 'val_loss' reached 0.81441 (best 0.80679), saving model to '/home/edu/turing/projects/sprint13-mushrooms/project/models/training/checkpoint_resnet50-v121.ckpt' as top 10
Validation: | | 0/? [00:00<?, ?it/s]
Epoch 32, global step 6963: 'val_loss' reached 0.79239 (best 0.79239), saving model to '/home/edu/turing/projects/sprint13-mushrooms/project/models/training/checkpoint_resnet50-v115.ckpt' as top 10
Validation: | | 0/? [00:00<?, ?it/s]
Epoch 33, global step 7174: 'val_loss' reached 0.80597 (best 0.79239), saving model to '/home/edu/turing/projects/sprint13-mushrooms/project/models/training/checkpoint_resnet50-v112.ckpt' as top 10
Validation: | | 0/? [00:00<?, ?it/s]
Epoch 34, global step 7385: 'val_loss' reached 0.81312 (best 0.79239), saving model to '/home/edu/turing/projects/sprint13-mushrooms/project/models/training/checkpoint_resnet50-v117.ckpt' as top 10
Validation: | | 0/? [00:00<?, ?it/s]
Epoch 35, global step 7596: 'val_loss' reached 0.79496 (best 0.79239), saving model to '/home/edu/turing/projects/sprint13-mushrooms/project/models/training/checkpoint_resnet50-v116.ckpt' as top 10
Validation: | | 0/? [00:00<?, ?it/s]
Epoch 36, global step 7807: 'val_loss' reached 0.77536 (best 0.77536), saving model to '/home/edu/turing/projects/sprint13-mushrooms/project/models/training/checkpoint_resnet50-v121.ckpt' as top 10
Validation: | | 0/? [00:00<?, ?it/s]
Epoch 37, global step 8018: 'val_loss' reached 0.79161 (best 0.77536), saving model to '/home/edu/turing/projects/sprint13-mushrooms/project/models/training/checkpoint_resnet50-v117.ckpt' as top 10
Validation: | | 0/? [00:00<?, ?it/s]
Epoch 38, global step 8229: 'val_loss' reached 0.78940 (best 0.77536), saving model to '/home/edu/turing/projects/sprint13-mushrooms/project/models/training/checkpoint_resnet50-v114.ckpt' as top 10
Validation: | | 0/? [00:00<?, ?it/s]
Epoch 39, global step 8440: 'val_loss' reached 0.78410 (best 0.77536), saving model to '/home/edu/turing/projects/sprint13-mushrooms/project/models/training/checkpoint_resnet50-v120.ckpt' as top 10
Validation: | | 0/? [00:00<?, ?it/s]
Epoch 40, global step 8651: 'val_loss' reached 0.77464 (best 0.77464), saving model to '/home/edu/turing/projects/sprint13-mushrooms/project/models/training/checkpoint_resnet50-v113.ckpt' as top 10
Validation: | | 0/? [00:00<?, ?it/s]
Epoch 41, global step 8862: 'val_loss' reached 0.79397 (best 0.77464), saving model to '/home/edu/turing/projects/sprint13-mushrooms/project/models/training/checkpoint_resnet50-v118.ckpt' as top 10
Validation: | | 0/? [00:00<?, ?it/s]
Epoch 42, global step 9073: 'val_loss' reached 0.77888 (best 0.77464), saving model to '/home/edu/turing/projects/sprint13-mushrooms/project/models/training/checkpoint_resnet50-v119.ckpt' as top 10
Validation: | | 0/? [00:00<?, ?it/s]
Epoch 43, global step 9284: 'val_loss' reached 0.77926 (best 0.77464), saving model to '/home/edu/turing/projects/sprint13-mushrooms/project/models/training/checkpoint_resnet50-v112.ckpt' as top 10
Validation: | | 0/? [00:00<?, ?it/s]
Epoch 44, global step 9495: 'val_loss' reached 0.78799 (best 0.77464), saving model to '/home/edu/turing/projects/sprint13-mushrooms/project/models/training/checkpoint_resnet50-v116.ckpt' as top 10
Validation: | | 0/? [00:00<?, ?it/s]
Epoch 45, global step 9706: 'val_loss' reached 0.78845 (best 0.77464), saving model to '/home/edu/turing/projects/sprint13-mushrooms/project/models/training/checkpoint_resnet50-v118.ckpt' as top 10
Validation: | | 0/? [00:00<?, ?it/s]
Epoch 46, global step 9917: 'val_loss' reached 0.79021 (best 0.77464), saving model to '/home/edu/turing/projects/sprint13-mushrooms/project/models/training/checkpoint_resnet50-v115.ckpt' as top 10
While training our model, we comfortably reach a loss of ~0.77 on unseen data (the val split).
💡 Since we are using PyTorch Lightning, we don't need to call model.train()/model.eval() manually; the framework does this for us.
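For contrast, here is a rough sketch of the bookkeeping Lightning performs around validation/test loops if we had to write it in plain PyTorch ourselves (`model` and `loader` are placeholders for any module and dataloader):

```python
import torch


def manual_validation(model, loader, device="cpu"):
    """The housekeeping Lightning does for us around validation:
    switch to eval mode, disable autograd, then restore train mode."""
    model.eval()  # disables dropout, uses BatchNorm running stats
    correct, total = 0, 0
    with torch.no_grad():  # no gradient tracking during inference
        for x, y in loader:
            preds = model(x.to(device)).argmax(dim=1)
            correct += (preds == y.to(device)).sum().item()
            total += y.numel()
    model.train()  # hand control back to the training loop
    return correct / max(total, 1)
```

Forgetting the eval/no_grad pair is a classic source of silently wrong validation numbers, which is exactly why it is nice that Lightning handles it.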
Let's see how well our model performs on our final test split:
trainer.test(best_model, test_loader)
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Testing: | | 0/? [00:00<?, ?it/s]
        Test metric             DataLoader 0
 ────────────────────────────────────────────────
  test_accuracy_epoch       0.15733768045902252
  test_f1_epoch             0.1711704581975937
  test_loss                 0.8650455474853516
  test_precision_epoch      0.19689351320266724
  test_recall_epoch         0.15733768045902252
[{'test_loss': 0.8650455474853516,
'test_accuracy_epoch': 0.15733768045902252,
'test_precision_epoch': 0.19689351320266724,
'test_recall_epoch': 0.15733768045902252,
'test_f1_epoch': 0.1711704581975937}]
⚠️ Remember that these are epoch-based metrics, not the final results. The only one we can take into consideration right now is test_loss.
➡️ We will look at the overall performance further down.
Also, the model does not show signs of overfitting: the train/val loss was ~0.80, and on test (unseen data) it is 0.86.
We can use TensorBoard to visualize and compare the different runs of our CNN.

predicted_test = trainer.predict(best_model, test_loader)
predicted_test = torch.cat(predicted_test)
predicted_test
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Predicting: | | 0/? [00:00<?, ?it/s]
tensor([0, 1, 0, ..., 8, 8, 8])
actual = [y for x, y in test_loader.dataset]
assert len(predicted_test) == len(actual)
print(f"{len(predicted_test) = }")
len(predicted_test) = 1007
cm = MulticlassConfusionMatrix(num_classes=9, normalize="true")
confusion_matrix = cm(predicted_test, torch.tensor(actual))
sns.heatmap(
    confusion_matrix,
    annot=True,
    fmt="0.00%",
    cmap=sns.light_palette(moonstone, as_cmap=False, n_colors=10),
)
<AxesSubplot: >
This is clearly terrible. Almost all trials perform poorly out of the box.
None of them get even close to the performance we got when we did HP tuning with the larger dataset.
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score


def calculate_metrics(y_true, y_pred):
    accuracy = accuracy_score(y_true, y_pred)
    precision = precision_score(y_true, y_pred, average="weighted")
    recall = recall_score(y_true, y_pred, average="weighted")
    f1 = f1_score(y_true, y_pred, average="weighted")
    print(f"Accuracy: {accuracy:.3f}")
    print(f"Precision: {precision:.3f}")
    print(f"Recall: {recall:.3f}")
    print(f"F1 Score: {f1:.3f}")


# note the argument order: ground truth first, predictions second
calculate_metrics(torch.tensor(actual), predicted_test)
Accuracy: 0.694
Precision: 0.713
Recall: 0.694
F1 Score: 0.692
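The weighted averages hide which genera get confused with which; sklearn's `classification_report` breaks the same numbers down per class in one call. A small sketch (the class-name list passed in the usage comment is a hypothetical placeholder for the 9 genus labels):

```python
from sklearn.metrics import classification_report


def per_class_report(y_true, y_pred, class_names=None):
    """Per-class precision/recall/F1 plus the same weighted
    averages printed above, in one text table."""
    return classification_report(
        y_true, y_pred, target_names=class_names, zero_division=0
    )


# e.g. print(per_class_report(actual, predicted_test, class_names=genus_names))
```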
It seems that PyTorch is leveraging the GPU properly during training

Let's see how this translates to our training/inference speed
The training speed for our model is around 14 items/sec

The inference speed on GPU is also around 14 items/sec

Let's see how this compares with inference on CPU.
We just want a rough idea of the performance difference and orders of magnitude.
@run
def inference_on_cpu():
    cpu_model = best_model.to("cpu")
    cpu_trainer = pl.Trainer(accelerator="cpu")
    test_loader = torch.utils.data.DataLoader(
        test_dataset, batch_size=32, shuffle=False, num_workers=2, pin_memory=True
    )
    cpu_trainer.test(cpu_model, test_loader)

When using the CPU, our inference speed drops to roughly 1 item/second (from ~14/sec on GPU)
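The items/sec figures above are read off the Lightning progress bars; for reference, a generic way to measure throughput for any callable looks like this (a sketch with a dummy workload standing in for model inference; names are mine):

```python
import time

def throughput(fn, items, repeats=1):
    """Return items processed per second when fn is applied to each item."""
    start = time.perf_counter()
    for _ in range(repeats):
        for item in items:
            fn(item)
    elapsed = time.perf_counter() - start
    return (len(items) * repeats) / elapsed if elapsed > 0 else float("inf")

# dummy workload standing in for a forward pass
rate = throughput(lambda x: sum(range(1000)), items=list(range(100)))
```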
While this was not required, I wanted to try running this model on an old Coral TPU I have lying around.


Despite the effort, I could not get past step 2. The conversion between ONNX and TF was failing consistently.
I am not worried about this, since it is beyond the scope of the requirements, but if you have experience with these devices/frameworks/standards, I'd love to get some of your insights and advice. :)
import torch
from pytorch_lightning import LightningModule
def convert_to_onnx(model: LightningModule, input_sample: torch.Tensor, onnx_path: str):
model.to_onnx(onnx_path, input_sample)
convert_to_onnx(
best_model, input_sample=torch.randn(1, 3, 224, 224), onnx_path="best_model.onnx"
)
if run_entire_notebook("onnx"):
!onnx-tf convert -i best_model.onnx -o best_model.pb
skipping optional operation ==== printing cached output ====
2023-12-10 10:51:49 I tensorflow: oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2023-12-10 10:51:49 E stream_executor/cuda: Unable to register cuDNN/cuFFT/cuBLAS factory: Attempting to register factory for plugin when one has already been registered
2023-12-10 10:51:49 W tf2tensorrt: TF-TRT Warning: Could not find TensorRT
UserWarning: TensorFlow Addons (TFA) has ended development and introduction of new features. TFA has entered a minimal maintenance and release mode until a planned end of life in May 2024. Please modify downstream libraries to take dependencies from other repositories in our TensorFlow community (e.g. Keras, Keras-CV, and Keras-NLP). For more information see: https://github.com/tensorflow/addons/issues/2807
2023-12-10 10:51:50,734 - onnx-tf - INFO - Start converting onnx pb to tf saved model
2023-12-10 10:51:50 I gpu_device: Created device /job:localhost/replica:0/task:0/device:GPU:0 with 8089 MB memory: device 0, name: NVIDIA GeForce RTX 3060, pci bus id: 0000:01:00.0, compute capability: 8.6
The name of the input tensor seems to be correct and standard:
import onnx
model = onnx.load("best_model.onnx")
print([input.name for input in model.graph.input])
['input.1']
import tensorflow as tf
def quantize_and_convert_to_tflite(tf_model_path: str, tflite_model_path: str):
converter = tf.lite.TFLiteConverter.from_saved_model(tf_model_path)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
tflite_quant_model = converter.convert()
with open(tflite_model_path, "wb") as f:
f.write(tflite_quant_model)
It seems that there are some incompatibilities when converting our model to TensorFlow.
Unfortunately, we ran out of time on this project and cannot pursue it further, but it was fun to try nonetheless.
In our project, we tuned a Convolutional Neural Network (CNN) for mushroom classification using PyTorch Lightning. Our goal was to accurately classify different species of mushrooms based on images, a task with significant implications in fields such as mycology and food safety.
We began by fetching a diverse dataset of mushroom images from Kaggle, ensuring a wide representation of different species. We then preprocessed the images and used a data loader for efficient input into our model. To address class imbalance, we implemented a rebalancing strategy, ensuring each class had an equal chance of being represented during training.
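The core of a rebalancing strategy like the one described can be sketched as inverse-frequency sampling weights; the exact mechanism used in the notebook (e.g. a `WeightedRandomSampler` fed with such weights) may differ, so this shows only the weight computation:

```python
from collections import Counter

def sample_weights(labels):
    """One weight per sample, inversely proportional to its class frequency,
    so each class is drawn with equal probability in expectation."""
    freq = Counter(labels)
    return [1.0 / freq[y] for y in labels]

labels = [0, 0, 0, 1]  # class 0 is 3x over-represented
weights = sample_weights(labels)
# total weight per class is now equal: 3 * (1/3) == 1 * (1/1)
```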
Our CNN model was designed to use transfer learning in order to leverage an existing image classification model (ResNet-50). We applied various optimization techniques, including adjusting batch sizes, the number of workers in the loader, and learning rates, using Optuna for hyperparameter tuning. We also employed PyTorch Lightning's advanced features like GPU acceleration and automatic differentiation, which significantly streamlined our training/tuning process.
Throughout the project, we learned valuable lessons. We found that balancing the dataset improved model performance significantly. We also learned to monitor GPU memory usage closely, as it directly impacted the training speed. Furthermore, we discovered the importance of fine-tuning hyperparameters, such as the learning rate and batch size, to optimize model performance.
We also explored and compared performance between CPU and GPU runs, and even attempted to convert the model to TensorFlow to try to run it on a USB TPU. This did not work, but it was interesting to learn about ONNX and how to translate models from one framework to another.
Executive Summary
Our CNN mushroom classification model has shown promising results. The key metrics on the held-out test set are as follows:
Accuracy: 0.694
Precision (weighted): 0.713
Recall (weighted): 0.694
F1 Score (weighted): 0.692
These metrics indicate the model's ability to classify mushroom species accurately. Overall, it performs reasonably well.
If we wanted to focus on keeping people safe (rather than on identifying the exact mushroom type), we could also convert this classification task from multiclass to a binary "safe/unsafe" one, which would likely yield more actionable results. However, that would take the project into more difficult terrain, as it would assume more responsibility. Right now the decision, and the responsibility for it, falls on the user, which feels better and safer given the circumstances.
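Collapsing the multiclass output to binary could be as simple as a class-to-safety lookup applied after prediction. The mapping below is entirely hypothetical, chosen only to illustrate the post-processing step, and is emphatically not edibility information:

```python
# HYPOTHETICAL mapping from class index to "unsafe" -- illustration only,
# NOT real mycological or edibility information.
UNSAFE = {3, 7}  # pretend these class indices are toxic genera

def to_binary(predictions, unsafe=UNSAFE):
    """Map multiclass predictions to 1 = unsafe, 0 = safe."""
    return [1 if p in unsafe else 0 for p in predictions]

binary = to_binary([0, 3, 8, 7])
```

One design consequence worth noting: with a binary target, the metrics that matter most become recall on the "unsafe" class (missed toxic mushrooms are the costly error), not overall accuracy.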